import random
import itertools

import numpy as np

import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import seeding


######################################################
# Block-switch bandit
class BlockBanditEnv(gym.Env):
    """
    Bandit environment base with change of reward in block structure, time as state.
    ---
    block_lens:
        Number of trials in each block; the episode lasts sum(block_lens) steps.
    block_p_bandits:
        For each block, a list of reward probabilities (one per bandit).
    block_r_bandits:
        For each block, a list of either rewards (if number) or
        [mean, standard deviation] pairs (if list) of the payout for each bandit.
    info:
        Info about the environment that the agent is not supposed to know:
        the full block schedule, the current block and its reward probabilities.
        Can be useful to evaluate the agent's performance.
    """
    def __init__(self, 
        block_lens, block_p_bandits, block_r_bandits
    ):

        self.block_lens = block_lens
        self.block_p_bandits = block_p_bandits
        self.block_r_bandits = block_r_bandits
        self.max_episode_steps = np.sum(self.block_lens)

        # state space
        self.observation_space = spaces.Dict({
            "timestep": spaces.Box(low=0, high=self.max_episode_steps, dtype=np.int64),  # timestep
        })
        # initialised here too so _get_info() works before the first reset()
        self.timestep = np.array([0])
        self.current_block = 0

        # action space: one discrete action per arm
        self.k_bandits = len(self.block_p_bandits[0])
        self.action_space = spaces.Discrete(self.k_bandits)

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _get_obs(self):
        # Copy: step() increments self.timestep in place, which would otherwise
        # rewrite every observation the caller has already stored.
        return {"timestep": self.timestep.copy()}
    
    def _get_info(self):
        return {
            "block_lens": self.block_lens,
            "reward_prob": self.block_p_bandits,
            "reward_size": self.block_r_bandits,
            "timestep": self.timestep.copy(),
            "current_block": self.current_block,
            "current_reward_prob": self.block_p_bandits[self.current_block]
        }

    def reset(self, seed=None, options=None):
        """
        The reset method will be called to initiate a new episode. 
        You may assume that the step method will not be called before reset has been called. 
        Moreover, reset should be called whenever a done signal has been issued.
        This should *NOT* automatically reset the task! Resetting the task is 
        handled in the wrapper.
        """
        # seed self.np_random
        # pass an integer for RHG right after the environment has been initialized 
        # and then never again
        # NOTE(review): rewards are sampled with the global `random` module, not
        # self.np_random, so this seed does not make rewards reproducible.
        super().reset(seed=seed)

        # reset timestep and block pointer
        self.timestep = np.array([0])
        self.current_block = 0

        return self._get_obs(), self._get_info()

    def step(self, action):
        """
        Execute one step in the environment.
        Should return: (observation, reward, terminated, truncated, info)
        If terminated or truncated is true, the user needs to call reset().
        """
        # action should be type integer in [0, k_bandits-1]
        assert self.action_space.contains(action)

        # state transition (in place; observations are returned as copies)
        self.timestep += 1

        # An episode is done iff max_episode_steps is reached
        terminated = bool(self.timestep >= self.max_episode_steps)

        # compute reward with the probabilities of the current block
        reward = 0
        p_bandits = self.block_p_bandits[self.current_block]
        r_bandits = self.block_r_bandits[self.current_block]
        if random.uniform(0, 1) < p_bandits[action]:
            if not isinstance(r_bandits[action], list):
                reward = r_bandits[action]
            else:
                # [mean, std] payout
                reward = random.gauss(r_bandits[action][0], r_bandits[action][1])

        # move to the next block once the current block's trials are used up
        if not terminated:
            if self.timestep >= np.array(self.block_lens[:self.current_block+1]).sum():
                self.current_block += 1

        return self._get_obs(), reward, terminated, False, self._get_info()


class BlockBandit2ArmCoupledEasy(BlockBanditEnv):
    """
    Stochastic two-arm block bandit with a large gap between the arms.

    Each block assigns one ordering of ``prob_pool`` to the two arms, and the
    ordering flips every block. Block lengths are drawn uniformly within
    ``block_len_var`` of the mean; the final block absorbs the remainder so
    the lengths always add up to ``total_trials``. Every payout is 1.
    """
    def __init__(
        self,
        total_trials=200, num_blocks=5, block_len_var=10,
        prob_pool = [0.8, 0.2]
    ):
        # block lengths: uniform around the mean, remainder goes to the last block
        mean_len = int(total_trials/ num_blocks)
        lo, hi = mean_len - block_len_var, mean_len + block_len_var
        lens = [int(random.uniform(lo, hi)) for _ in range(num_blocks - 1)]
        lens.append(total_trials - np.array(lens).sum())
        block_lens = np.array(lens)

        # probabilities: alternate between the two orderings of prob_pool
        orderings = list(itertools.permutations(prob_pool))
        order_idx = random.randint(0, len(orderings) - 1)
        block_p_bandits = []
        for _ in range(num_blocks):
            block_p_bandits.append(orderings[order_idx])
            if order_idx not in (0, 1):
                raise ValueError(f'permutation index out of bound')
            order_idx = 1 - order_idx

        # unit payout on both arms in every block
        block_r_bandits = [[1, 1] for _ in range(num_blocks)]

        BlockBanditEnv.__init__(
            self,
            block_lens=block_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )


class BlockBandit2ArmCoupledMultipleProb(BlockBanditEnv):
    """
    Block-switched bandit with a pool of reward probabilities to choose from.

    The high-reward side alternates every block. Each block's high-reward
    probability is drawn from ``high_reward_prob_pool`` excluding the previous
    block's value, and the other arm gets ``1 - p``. A fresh task (block
    lengths and probabilities) is generated on every reset that follows an
    episode in which at least one step was taken.
    """
    def __init__(
        self,
        total_trials=400, 
        num_blocks=10, 
        block_len_var=10,
        high_reward_prob_pool = {0.9, 0.7, 0.5}
    ):
        self.total_trials = total_trials
        self.num_blocks = num_blocks
        self.block_len_var = block_len_var
        self.high_reward_prob_pool = high_reward_prob_pool

        block_lens = self.gen_block_lens(
            total_trials=self.total_trials,
            num_blocks=self.num_blocks,
            block_len_var=self.block_len_var
        )  # array
        block_p_bandits = self.gen_block_p_bandits(
            num_blocks=self.num_blocks,
            high_reward_prob_pool=self.high_reward_prob_pool
        )  # list
        block_r_bandits = self.gen_block_r_bandits(num_blocks=self.num_blocks)

        # init the env
        BlockBanditEnv.__init__(
            self, 
            block_lens=block_lens, 
            block_p_bandits=block_p_bandits, 
            block_r_bandits=block_r_bandits
        )

    def gen_block_lens(self, total_trials, num_blocks, block_len_var):
        """Draw block lengths uniformly around the mean; the last block takes the remainder."""
        avg_block_len = int(total_trials/ num_blocks)
        block_lens = []
        for block_ind in range(num_blocks-1):
            block_lens.append(int(random.uniform(
                avg_block_len-block_len_var, avg_block_len+block_len_var
            )))
        # last block makes the total exactly total_trials
        block_lens.append(total_trials - np.array(block_lens).sum())
        block_lens = np.array(block_lens)
        return block_lens

    def gen_block_p_bandits(self, num_blocks, high_reward_prob_pool):
        """
        Build per-block [p_arm0, p_arm1] lists for coupled two-arm blocks.

        high_reward_prob_pool must be a set with at least two values:
        consecutive blocks never reuse the same high-reward probability.
        """
        block_p_bandits = []
        
        curr_block_high_reward_side = random.randint(0, 1)
        curr_block_high_reward_prob = random.choice(list(high_reward_prob_pool))
        for block_ind in range(num_blocks):
            p_bandits = np.zeros(2)
            p_bandits[curr_block_high_reward_side] = curr_block_high_reward_prob
            p_bandits[1-curr_block_high_reward_side] = 1 - curr_block_high_reward_prob
            block_p_bandits.append(p_bandits.tolist())

            # block switch: the high-reward side alternates
            curr_block_high_reward_side = 1 - curr_block_high_reward_side
            
            # choose next block high_reward_prob (different from the current one)
            curr_block_high_reward_prob = random.choice(
                list(high_reward_prob_pool - {curr_block_high_reward_prob})
            )
        
        return block_p_bandits

    def gen_block_r_bandits(self, num_blocks):
        """Unit payout on both arms in every block."""
        block_r_bandits = [[1,1] for _ in range(num_blocks)]

        return block_r_bandits

    def reset(self, seed=None, options=None):
        """
        Reset the episode, regenerating the task if the previous episode ran.

        The task built in __init__ is used as-is until the first step; any
        later reset that follows at least one step draws new block lengths
        and reward probabilities. The seed is always forwarded to the base
        class, including on the very first reset (previously it was dropped
        there, exactly where callers are told to pass it).
        """
        # must be decided before super().reset(), which zeroes self.timestep
        episode_has_run = bool(np.any(self.timestep != 0))

        # seeds self.np_random and resets timestep / current_block
        super().reset(seed=seed, options=options)

        if episode_has_run:
            self.block_lens = self.gen_block_lens(
                total_trials=self.total_trials,
                num_blocks=self.num_blocks,
                block_len_var=self.block_len_var
            )  # array
            self.block_p_bandits = self.gen_block_p_bandits(
                num_blocks=self.num_blocks,
                high_reward_prob_pool=self.high_reward_prob_pool
            )  # list
            self.block_r_bandits = self.gen_block_r_bandits(num_blocks=self.num_blocks)
            self.max_episode_steps = np.sum(self.block_lens)

        observation = self._get_obs()
        info = self._get_info()

        return observation, info


class BlockBandit2ArmCoupledEasyMixedBlockLength(BlockBanditEnv):
    """
    Stochastic two-arm block bandit whose number of blocks is random.

    The number of blocks is picked from ``num_blocks_pool``. Each block uses
    one ordering of ``prob_pool``, flipping every block. Block lengths are
    drawn uniformly within ``block_len_var`` of the mean, with the last block
    absorbing the remainder. Every payout is 1.
    """
    def __init__(
        self,
        total_trials=400, num_blocks_pool=[5, 10], block_len_var=10,
        prob_pool = [0.8, 0.2]
    ):
        n_blocks = random.choice(num_blocks_pool)

        # block lengths: uniform around the mean, remainder goes to the last block
        mean_len = int(total_trials/ n_blocks)
        lo, hi = mean_len - block_len_var, mean_len + block_len_var
        lens = [int(random.uniform(lo, hi)) for _ in range(n_blocks - 1)]
        lens.append(total_trials - np.array(lens).sum())
        block_lens = np.array(lens)

        # probabilities: alternate between the two orderings of prob_pool
        orderings = list(itertools.permutations(prob_pool))
        order_idx = random.randint(0, len(orderings) - 1)
        block_p_bandits = []
        for _ in range(n_blocks):
            block_p_bandits.append(orderings[order_idx])
            if order_idx not in (0, 1):
                raise ValueError(f'permutation index out of bound')
            order_idx = 1 - order_idx

        # unit payout on both arms in every block
        block_r_bandits = [[1, 1] for _ in range(n_blocks)]

        BlockBanditEnv.__init__(
            self,
            block_lens=block_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )


class BlockBandit2ArmMixedCoupledMultipleProbAndIndependent(BlockBanditEnv):
    """
    Block-switching bandit with either coupled or independent reward.

    A session is coupled or independent with equal chance.
    - Coupled: the high arm's probability is drawn from
      ``high_reward_prob_pool``, never repeating the previous block's value,
      and the other arm gets ``1 - p``.
    - Independent: two uniform draws are made per block; the larger one goes
      to the high-reward side.

    In both cases the high-reward side alternates every block.
    """

    def __init__(
        self,
        total_trials=400, num_blocks=10, block_len_var=10,
        high_reward_prob_pool = {0.9, 0.7, 0.5}
    ):
        # block lengths: uniform around the mean, remainder goes to the last block
        mean_len = int(total_trials/ num_blocks)
        lo, hi = mean_len - block_len_var, mean_len + block_len_var
        lens = [int(random.uniform(lo, hi)) for _ in range(num_blocks - 1)]
        lens.append(total_trials - np.array(lens).sum())
        block_lens = np.array(lens)

        # session type: coupled (False) or independent (True)
        is_independent = random.random() >= 0.5

        high_side = random.randint(0, 1)
        high_prob = random.choice(list(high_reward_prob_pool))

        block_p_bandits = []
        for _ in range(num_blocks):
            if is_independent:
                draws = [random.random() for _ in range(2)]
                p_high, p_low = max(draws), min(draws)
            else:
                p_high, p_low = high_prob, 1 - high_prob
            probs = np.zeros(2)
            probs[high_side] = p_high
            probs[1 - high_side] = p_low
            block_p_bandits.append(probs.tolist())

            # the high-reward side alternates every block
            high_side = 1 - high_side

            # coupled sessions pick a different high probability for the next block
            if not is_independent:
                high_prob = random.choice(
                    list(high_reward_prob_pool - {high_prob})
                )

        # unit payout on both arms in every block
        block_r_bandits = [[1, 1] for _ in range(num_blocks)]

        BlockBanditEnv.__init__(
            self,
            block_lens=block_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )


######################################################
# Block-switch bandit with baiting
class BaitedBlockBanditEnv(gym.Env):
    """
    Bandit environment base with change of reward in block structure, time as state.
    With baiting: once an arm is baited, its reward stays there until that arm is chosen.
    ---
    block_lens:
        Number of trials in each block; the episode lasts sum(block_lens) steps.
    block_p_bandits:
        For each block, a list of baiting probabilities (one per bandit).
    block_r_bandits:
        For each block, a list of either rewards (if number) or
        [mean, standard deviation] pairs (if list) of the payout for each bandit.
    info:
        Info about the environment that the agent is not supposed to know:
        the full block schedule, the current block, its probabilities and
        the current baits. Can be useful to evaluate the agent's performance.
    """
    def __init__(self, 
        block_lens, block_p_bandits, block_r_bandits
    ):

        self.block_lens = block_lens
        self.block_p_bandits = block_p_bandits
        self.block_r_bandits = block_r_bandits
        self.max_episode_steps = np.sum(self.block_lens)

        # state space
        self.observation_space = spaces.Dict({
            "timestep": spaces.Box(low=0, high=self.max_episode_steps, dtype=np.int64),  # timestep
            "rewards_at_each_arm": spaces.MultiBinary(len(self.block_p_bandits[0]))   # reward_assignment
        })
        # initialised here too so _get_info() works before the first reset()
        self.timestep = np.array([0])
        self.current_block = 0
        
        # action space
        self.k_bandits = len(self.block_p_bandits[0])
        self.action_space = spaces.Discrete(self.k_bandits)
        # 1 where an arm currently holds a bait, 0 otherwise
        self.rewards_at_each_arm = np.zeros(self.k_bandits, dtype=np.int8)

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _get_obs(self):
        # Copies: both arrays are mutated in place by step(), which would
        # otherwise rewrite observations the caller has already stored.
        return {
            "timestep": self.timestep.copy(),
            "rewards_at_each_arm": self.rewards_at_each_arm.copy()
        }
    
    def _get_info(self):
        return {
            "block_lens": self.block_lens,
            "reward_prob": self.block_p_bandits,
            "reward_size": self.block_r_bandits,
            "timestep": self.timestep.copy(),
            "current_block": self.current_block,
            "current_reward_prob": self.block_p_bandits[self.current_block],
            "rewards_at_each_arm": self.rewards_at_each_arm.copy()
        }

    def _bait(self, p_bandits):
        """Independently bait every currently empty arm with its probability in p_bandits."""
        for bandit in range(self.k_bandits):
            if self.rewards_at_each_arm[bandit] == 0:
                if random.uniform(0, 1) < p_bandits[bandit]:
                    self.rewards_at_each_arm[bandit] = 1

    def reset(self, seed=None, options=None):
        """
        The reset method will be called to initiate a new episode. 
        You may assume that the step method will not be called before reset has been called. 
        Moreover, reset should be called whenever a done signal has been issued.
        This should *NOT* automatically reset the task! Resetting the task is 
        handled in the wrapper.
        """
        # seed self.np_random
        # pass an integer for RHG right after the environment has been initialized 
        # and then never again
        # NOTE(review): baiting and payouts use the global `random` module, not
        # self.np_random, so this seed does not make episodes reproducible.
        super().reset(seed=seed)

        # reset timestep
        self.timestep = np.array([0])
        self.current_block = 0

        # start every episode without baits left over from the previous one,
        # then bait the arms for the first trial with block 0's probabilities
        self.rewards_at_each_arm = np.zeros(self.k_bandits, dtype=np.int8)
        self._bait(self.block_p_bandits[0])

        return self._get_obs(), self._get_info()

    def step(self, action):
        """
        Execute one step in the environment.
        Should return: (observation, reward, terminated, truncated, info)
        If terminated or truncated is true, the user needs to call reset().
        """
        # action should be type integer in [0, k_bandits-1]
        assert self.action_space.contains(action)

        # state transition (in place; observations are returned as copies)
        self.timestep += 1

        # An episode is done iff max_episode_steps is reached
        terminated = bool(self.timestep >= self.max_episode_steps)

        # collect the bait on the chosen arm, if there is one
        reward = 0
        if self.rewards_at_each_arm[action] == 1:
            r_bandit = self.block_r_bandits[self.current_block][action]
            if isinstance(r_bandit, list):
                # [mean, std] payout
                reward = random.gauss(r_bandit[0], r_bandit[1])
            else:
                reward = r_bandit
            self.rewards_at_each_arm[action] = 0

        # update current block first, so the baits drawn below for the next
        # trial use that trial's block (as reset() does for trial 0)
        if not terminated:
            if self.timestep >= np.array(self.block_lens[:self.current_block+1]).sum():
                self.current_block += 1

        # baiting for the next trial
        self._bait(self.block_p_bandits[self.current_block])

        return self._get_obs(), reward, terminated, False, self._get_info()


class BaitedBlockBandit2ArmCoupledEasy(BaitedBlockBanditEnv):
    """
    Stochastic two-arm block bandit with a large gap between the arms, with baiting.

    Each block assigns one ordering of ``prob_pool`` (the baiting
    probabilities) to the two arms, flipping every block. Block lengths are
    drawn uniformly within ``block_len_var`` of the mean, with the last block
    absorbing the remainder. Every payout is 1.
    """
    def __init__(
        self,
        total_trials=400, num_blocks=10, block_len_var=10,
        prob_pool = [0.8, 0.2]
    ):
        # block lengths: uniform around the mean, remainder goes to the last block
        mean_len = int(total_trials/ num_blocks)
        lo, hi = mean_len - block_len_var, mean_len + block_len_var
        lens = [int(random.uniform(lo, hi)) for _ in range(num_blocks - 1)]
        lens.append(total_trials - np.array(lens).sum())
        block_lens = np.array(lens)

        # probabilities: alternate between the two orderings of prob_pool
        orderings = list(itertools.permutations(prob_pool))
        order_idx = random.randint(0, len(orderings) - 1)
        block_p_bandits = []
        for _ in range(num_blocks):
            block_p_bandits.append(orderings[order_idx])
            if order_idx not in (0, 1):
                raise ValueError(f'permutation index out of bound')
            order_idx = 1 - order_idx

        # unit payout on both arms in every block
        block_r_bandits = [[1, 1] for _ in range(num_blocks)]

        BaitedBlockBanditEnv.__init__(
            self,
            block_lens=block_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )


class BaitedBlockBandit2ArmCoupledMultipleProb(BaitedBlockBanditEnv):
    """
    Block-switched bandit with a pool of reward probabilities, with baiting.

    The high-reward side alternates every block. Its probability is drawn
    from ``high_reward_prob_pool``, never repeating the previous block's
    value, and the other arm gets ``1 - p``. Every payout is 1.
    """
    def __init__(
        self,
        total_trials=400, num_blocks=10, block_len_var=10,
        high_reward_prob_pool = {0.9, 0.7, 0.5}
    ):
        # block lengths: uniform around the mean, remainder goes to the last block
        mean_len = int(total_trials/ num_blocks)
        lo, hi = mean_len - block_len_var, mean_len + block_len_var
        lens = [int(random.uniform(lo, hi)) for _ in range(num_blocks - 1)]
        lens.append(total_trials - np.array(lens).sum())
        block_lens = np.array(lens)

        high_side = random.randint(0, 1)
        high_prob = random.choice(list(high_reward_prob_pool))

        block_p_bandits = []
        for _ in range(num_blocks):
            probs = np.zeros(2)
            probs[high_side] = high_prob
            probs[1 - high_side] = 1 - high_prob
            block_p_bandits.append(probs.tolist())

            # alternate side and pick a different high probability next block
            high_side = 1 - high_side
            high_prob = random.choice(
                list(high_reward_prob_pool - {high_prob})
            )

        # unit payout on both arms in every block
        block_r_bandits = [[1, 1] for _ in range(num_blocks)]

        BaitedBlockBanditEnv.__init__(
            self,
            block_lens=block_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )


######################################################
# Random-walk bandit
class RandomWalkBanditEnv(gym.Env):
    """
    Bandit environment base with change of reward 
    following a random walk process, time as state.
    ---
    rw_p_bandits:
        Per-trial reward probabilities, indexed as rw_p_bandits[trial][bandit];
        the episode lasts len(rw_p_bandits) steps.
    rw_r_bandits:
        Per-trial payouts, indexed the same way; each entry is either a
        reward (if number) or [mean, standard deviation] (if list).
    info:
        Info about the environment that the agent is not supposed to know
        (the full probability / payout schedules).
        Can be useful to evaluate the agent's performance.
    """
    def __init__(self, 
        rw_p_bandits, 
        rw_r_bandits
    ):

        self.rw_p_bandits = rw_p_bandits
        self.rw_r_bandits = rw_r_bandits
        self.max_episode_steps = len(rw_p_bandits)

        # state space
        self.observation_space = spaces.Dict({
            "timestep": spaces.Box(low=0, high=self.max_episode_steps, dtype=np.int64),  # timestep
        })
        # initialised here too so _get_info() works before the first reset()
        self.timestep = np.array([0])
        
        # action space
        self.k_bandits = len(self.rw_p_bandits[0])
        self.action_space = spaces.Discrete(self.k_bandits)

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _get_obs(self):
        # Copy: step() increments self.timestep in place.
        return {
            "timestep": self.timestep.copy()
        }
    
    def _get_info(self):
        return {
            "reward_prob": self.rw_p_bandits,
            "reward_size": self.rw_r_bandits,
            "timestep": self.timestep.copy()
        }

    def reset(self, seed=None, options=None):
        """
        The reset method will be called to initiate a new episode. 
        You may assume that the step method will not be called before reset has been called. 
        Moreover, reset should be called whenever a done signal has been issued.
        This should *NOT* automatically reset the task! Resetting the task is 
        handled in the wrapper.
        """
        # seed self.np_random
        # pass an integer for RHG right after the environment has been initialized 
        # and then never again
        # NOTE(review): rewards use the global `random` module, not self.np_random.
        super().reset(seed=seed)

        # reset timestep
        self.timestep = np.array([0])

        return self._get_obs(), self._get_info()

    def step(self, action):
        """
        Execute one step in the environment.
        Should return: (observation, reward, terminated, truncated, info)
        If terminated or truncated is true, the user needs to call reset().
        """
        # action should be type integer in [0, k_bandits-1]
        assert self.action_space.contains(action)
        
        # state transition (in place; observations are returned as copies)
        self.timestep += 1

        # An episode is done iff max_episode_steps is reached
        terminated = bool(self.timestep >= self.max_episode_steps)

        # probabilities / payouts of the trial just played (0-indexed); plain
        # integer indexing works for numpy arrays and nested lists alike
        trial = int(self.timestep[0]) - 1
        p_bandits = self.rw_p_bandits[trial]
        r_bandits = self.rw_r_bandits[trial]

        reward = 0
        if random.uniform(0, 1) < p_bandits[action]:
            if not isinstance(r_bandits[action], list):
                reward = r_bandits[action]
            else:
                # [mean, std] payout
                reward = random.gauss(r_bandits[action][0], r_bandits[action][1])

        return self._get_obs(), reward, terminated, False, self._get_info()


class RandomWalkBandit2ArmGaussian(RandomWalkBanditEnv):
    """
    Two-armed bandit whose reward probabilities follow a clipped Gaussian random walk.

    Both arms start at ``random_walk_start``. On every trial each arm takes
    an independent N(0, 1) step scaled by ``random_walk_drift_rate``, and the
    result is clipped to [0, 1]. Every payout is 1.
    """
    def __init__(
        self,
        total_trials=400,
        random_walk_drift_rate=0.05,
        random_walk_start=0.5
    ):
        probs = np.zeros((total_trials, 2))
        probs[0, :] = np.ones(2) * random_walk_start
        for t in range(1, total_trials):
            noise = np.array([random.gauss(0, 1), random.gauss(0, 1)])
            probs[t, :] = np.clip(probs[t - 1, :] + noise * random_walk_drift_rate, 0.0, 1.0)

        # unit payout on both arms for every trial
        payouts = np.ones((total_trials, 2), dtype=np.int8)

        RandomWalkBanditEnv.__init__(
            self,
            rw_p_bandits=probs,
            rw_r_bandits=payouts,
        )


class RandomWalkBandit2ArmGaussianMixedDriftRate(RandomWalkBanditEnv):
    """
    Two-armed clipped Gaussian random-walk bandit with a randomly chosen drift rate.

    The drift rate for the whole session is picked from
    ``random_walk_drift_rate_pool``. Both arms start at ``random_walk_start``
    and take independent N(0, 1) steps scaled by that rate, clipped to
    [0, 1]. Every payout is 1.
    """
    def __init__(
        self,
        total_trials=400,
        random_walk_drift_rate_pool=[0.2, 0.0125],
        random_walk_start=0.5
    ):
        drift_rate = random.choice(random_walk_drift_rate_pool)

        probs = np.zeros((total_trials, 2))
        probs[0, :] = np.ones(2) * random_walk_start
        for t in range(1, total_trials):
            noise = np.array([random.gauss(0, 1), random.gauss(0, 1)])
            probs[t, :] = np.clip(probs[t - 1, :] + noise * drift_rate, 0.0, 1.0)

        # unit payout on both arms for every trial
        payouts = np.ones((total_trials, 2), dtype=np.int8)

        RandomWalkBanditEnv.__init__(
            self,
            rw_p_bandits=probs,
            rw_r_bandits=payouts,
        )


######################################################
# Time-based Block-switch bandit
class TimedBlockBanditEnv(gym.Env):
    """
    Bandit environment base with change of reward in block structure, time as state.
    Each trial lasts several timesteps; a go-cue is shown on the 2nd timestep
    of every trial, and only actions taken while the go-cue is on are rewarded.
    ---
    block_lens:
        Number of trials in each block.
    trial_lens:
        Number of timesteps in each trial; the episode lasts sum(trial_lens) timesteps.
    block_p_bandits:
        For each block, a list of reward probabilities (one per bandit).
    block_r_bandits:
        For each block, a list of either rewards (if number) or
        [mean, standard deviation] pairs (if list) of the payout for each bandit.
    info:
        Info about the environment that the agent is not supposed to know
        (block schedule, current block/trial and its reward probabilities).
        Can be useful to evaluate the agent's performance.
    """
    def __init__(self, 
        block_lens, trial_lens, block_p_bandits, block_r_bandits
    ):

        self.block_lens = block_lens  # in unit of trials
        self.trial_lens = trial_lens  # in unit of timesteps
        self.block_p_bandits = block_p_bandits
        self.block_r_bandits = block_r_bandits
        # upper bound used for the observation space (in unit of timesteps)
        self.max_episode_steps = np.sum(self.block_lens) * np.max(self.trial_lens)

        # state space
        self.observation_space = spaces.Dict({
            "timestep": spaces.Box(low=0, high=self.max_episode_steps, dtype=np.int64),  # timestep
            "go_cue": spaces.Discrete(2)  # go-cue
        })

        # initialised here too so _get_info() works before the first reset()
        self.current_block = 0
        self.current_trial = 0
        self.timestep_within_trial = 0
        self.timestep = np.array([0])
        self.go_cue = 0
        
        # action space
        self.k_bandits = len(self.block_p_bandits[0]) 
        self.action_space = spaces.Discrete(self.k_bandits + 1)  # +1 for withholding, by default the last action

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _get_obs(self):
        # Copy: step() increments self.timestep in place.
        return {"timestep": self.timestep.copy(), 
                "go_cue": self.go_cue}
    
    def _get_info(self):
        return {
            "block_lens": self.block_lens,
            "reward_prob": self.block_p_bandits,
            "reward_size": self.block_r_bandits,
            "timestep": self.timestep.copy(),
            "go_cue": self.go_cue,
            "current_block": self.current_block,
            "current_trial": self.current_trial,
            "current_reward_prob": self.block_p_bandits[self.current_block]
        }

    def reset(self, seed=None, options=None):
        """
        The reset method will be called to initiate a new episode. 
        You may assume that the step method will not be called before reset has been called. 
        Moreover, reset should be called whenever a done signal has been issued.
        This should *NOT* automatically reset the task! Resetting the task is 
        handled in the wrapper.
        """
        # seed self.np_random
        # pass an integer for RHG right after the environment has been initialized 
        # and then never again
        # NOTE(review): rewards use the global `random` module, not self.np_random.
        super().reset(seed=seed)

        # reset current block and trial
        self.current_block = 0
        self.current_trial = 0
        self.timestep_within_trial = 0  # timestep in current trial
    
        # reset timestep and go-cue
        self.timestep = np.array([0])  # global timestep
        self.go_cue = 0

        return self._get_obs(), self._get_info()

    def step(self, action):
        """
        Execute one step in the environment.
        Should return: (observation, reward, terminated, truncated, info)
        If terminated or truncated is true, the user needs to call reset().
        """
        # action should be type integer in [0, k_bandits]
        # note: last action is withhold
        assert self.action_space.contains(action)
        
        # compute reward from the go-cue shown in the previous observation
        reward = 0
        if self.go_cue == 1:  # only get reward if go_cue is on
            p_bandits = self.block_p_bandits[self.current_block]
            r_bandits = self.block_r_bandits[self.current_block]
            if action < self.k_bandits:  # cannot be withholding
                if random.uniform(0, 1) < p_bandits[action]:
                    if not isinstance(r_bandits[action], list):
                        reward = r_bandits[action]
                    else:
                        # [mean, std] payout
                        reward = random.gauss(r_bandits[action][0], r_bandits[action][1])
        else:
            if action != self.k_bandits:  # get punishment if not withholding
                reward = -1

        # state transition (in place; observations are returned as copies)
        self.timestep += 1
        self.timestep_within_trial += 1
        # go_cue at the 2nd timestep within a trial
        self.go_cue = 1 if self.timestep_within_trial == 1 else 0

        # An episode is done iff every trial has been played
        terminated = bool(self.timestep >= np.sum(self.trial_lens))

        if not terminated:
            # advance the trial first, so the block check below already sees
            # the trial the next step belongs to (otherwise the block lags one
            # step, and with 1-step trials the next go-cue step is rewarded
            # with the previous block's probabilities)
            if self.timestep_within_trial >= self.trial_lens[self.current_trial]:
                self.current_trial += 1
                self.timestep_within_trial = 0
            # then advance the block once its trials are used up
            if self.current_trial >= np.array(self.block_lens[:self.current_block+1]).sum():
                self.current_block += 1

        return self._get_obs(), reward, terminated, False, self._get_info()


class TimedBlockBandit2ArmCoupledEasy(TimedBlockBanditEnv):
    """Two-armed timed block bandit with strongly contrasting, anti-correlated arms.

    In every block the two arms are rewarded with probabilities 0.8 and 0.2.
    The assignment flips at each block boundary, and the ordering used for the
    first block is chosen at random. Every arm has a reward size of 1.

    Parameters
    ----------
    total_trials : int
        Total number of trials in an episode, split across all blocks.
    num_blocks : int
        Number of blocks (must be >= 1 and <= total_trials).
    block_len_var : int or float
        Half-width of the uniform window around the average block length
        ``int(total_trials / num_blocks)``. The first ``num_blocks - 1`` block
        lengths are drawn from this window, and the last block takes the
        remaining trials.
    trial_len_range : sequence of int
        Candidate trial lengths (in timesteps). Each trial's length is drawn
        uniformly from these values; they are choices, not range bounds.

    Raises
    ------
    ValueError
        If ``num_blocks < 1``, if ``total_trials < num_blocks``, or if no block
        partition with at least one trial per block could be drawn.
    """

    def __init__(self, total_trials=200, num_blocks=5, block_len_var=10,
                 trial_len_range=(4, 5)):
        block_lens = self._sample_block_lens(
            total_trials, num_blocks, block_len_var)

        # Each trial independently gets one of the candidate lengths.
        trial_lens = np.array(
            [random.choice(trial_len_range) for _ in range(total_trials)]
        ).astype(int)

        # Reward probabilities as (arm 0, arm 1). The two orderings of the pool
        # alternate between consecutive blocks, starting from a random one.
        prob_pool_permutations = list(itertools.permutations([0.8, 0.2]))
        permutation_ind = random.randint(0, len(prob_pool_permutations) - 1)
        block_p_bandits = []
        for _ in range(num_blocks):
            block_p_bandits.append(prob_pool_permutations[permutation_ind])
            permutation_ind = 1 - permutation_ind  # switch which arm is better

        # Reward size of 1 for both arms in every block.
        block_r_bandits = [[1, 1] for _ in range(num_blocks)]

        super().__init__(
            block_lens=block_lens,
            trial_lens=trial_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )

    @staticmethod
    def _sample_block_lens(total_trials, num_blocks, block_len_var,
                           max_attempts=10_000):
        """Draw ``num_blocks`` block lengths, each >= 1, that sum to ``total_trials``.

        The first ``num_blocks - 1`` lengths are
        ``int(random.uniform(avg - block_len_var, avg + block_len_var))``.
        The last block takes the remainder. A draw that leaves any block with
        fewer than one trial is rejected and redrawn. An acceptable first draw
        therefore consumes exactly the random numbers a single pass would.
        """
        if num_blocks < 1:
            raise ValueError(f'num_blocks must be >= 1, got {num_blocks}')
        if total_trials < num_blocks:
            raise ValueError(
                f'total_trials ({total_trials}) must be >= num_blocks '
                f'({num_blocks}) so that every block has at least one trial')

        avg_block_len = int(total_trials / num_blocks)
        for _ in range(max_attempts):
            block_lens = [
                int(random.uniform(avg_block_len - block_len_var,
                                   avg_block_len + block_len_var))
                for _ in range(num_blocks - 1)
            ]
            # The last block absorbs the remainder so the total is exact.
            block_lens.append(total_trials - sum(block_lens))
            if min(block_lens) >= 1:
                return np.array(block_lens)

        raise ValueError(
            f'could not draw {num_blocks} non-empty blocks summing to '
            f'{total_trials} trials with block_len_var={block_len_var}; '
            'try a smaller block_len_var')


class TimedBlockBandit2ArmCoupledMultipleProb(TimedBlockBanditEnv):
    """Two-armed timed block bandit whose arm contrast changes from block to block.

    In every block one arm (the "high-reward side") is rewarded with
    probability ``p`` and the other with ``1 - p``.

    - The high-reward side alternates between consecutive blocks, starting on a
      random side.
    - ``p`` is drawn from ``high_reward_prob_pool``. For each new block it is
      redrawn while excluding the previous block's value, whenever the pool
      offers an alternative.
    - A pool value of 0.5 makes both arms equally rewarded in that block.
    - Every arm has a reward size of 1.

    Parameters
    ----------
    total_trials : int
        Total number of trials in an episode, split across all blocks.
    num_blocks : int
        Number of blocks (must be >= 1 and <= total_trials).
    block_len_var : int or float
        Half-width of the uniform window around the average block length
        ``int(total_trials / num_blocks)``. The first ``num_blocks - 1`` block
        lengths are drawn from this window, and the last block takes the
        remaining trials.
    trial_len_range : sequence of int
        Candidate trial lengths (in timesteps). Each trial's length is drawn
        uniformly from these values; they are choices, not range bounds.
    high_reward_prob_pool : iterable of float
        Candidate reward probabilities for the high-reward arm. Any iterable is
        accepted; duplicates are ignored.

    Raises
    ------
    ValueError
        If ``high_reward_prob_pool`` is empty, if ``num_blocks < 1``, if
        ``total_trials < num_blocks``, or if no block partition with at least
        one trial per block could be drawn.
    """

    def __init__(
        self,
        total_trials=400, num_blocks=10, block_len_var=10,
        trial_len_range=(4, 5),
        high_reward_prob_pool=frozenset({0.9, 0.65, 0.5})
    ):
        # Normalise to a set so list/tuple pools work with set difference below.
        prob_pool = set(high_reward_prob_pool)
        if not prob_pool:
            raise ValueError(
                'high_reward_prob_pool must contain at least one probability')

        block_lens = self._sample_block_lens(
            total_trials, num_blocks, block_len_var)

        # Each trial independently gets one of the candidate lengths.
        trial_lens = np.array(
            [random.choice(trial_len_range) for _ in range(total_trials)]
        ).astype(int)

        block_p_bandits = []
        high_side = random.randint(0, 1)
        high_prob = random.choice(list(prob_pool))
        for _ in range(num_blocks):
            p_bandits = np.zeros(2)
            p_bandits[high_side] = high_prob
            p_bandits[1 - high_side] = 1 - high_prob
            block_p_bandits.append(p_bandits.tolist())

            # Block switch: the high-reward arm moves to the other side ...
            high_side = 1 - high_side
            # ... and the next block uses a different probability if the pool
            # has one. A single-value pool keeps that value for every block.
            alternatives = prob_pool - {high_prob}
            if alternatives:
                high_prob = random.choice(list(alternatives))

        # Reward size of 1 for both arms in every block.
        block_r_bandits = [[1, 1] for _ in range(num_blocks)]

        super().__init__(
            block_lens=block_lens,
            trial_lens=trial_lens,
            block_p_bandits=block_p_bandits,
            block_r_bandits=block_r_bandits,
        )

    @staticmethod
    def _sample_block_lens(total_trials, num_blocks, block_len_var,
                           max_attempts=10_000):
        """Draw ``num_blocks`` block lengths, each >= 1, that sum to ``total_trials``.

        The first ``num_blocks - 1`` lengths are
        ``int(random.uniform(avg - block_len_var, avg + block_len_var))``.
        The last block takes the remainder. A draw that leaves any block with
        fewer than one trial is rejected and redrawn. An acceptable first draw
        therefore consumes exactly the random numbers a single pass would.
        """
        if num_blocks < 1:
            raise ValueError(f'num_blocks must be >= 1, got {num_blocks}')
        if total_trials < num_blocks:
            raise ValueError(
                f'total_trials ({total_trials}) must be >= num_blocks '
                f'({num_blocks}) so that every block has at least one trial')

        avg_block_len = int(total_trials / num_blocks)
        for _ in range(max_attempts):
            block_lens = [
                int(random.uniform(avg_block_len - block_len_var,
                                   avg_block_len + block_len_var))
                for _ in range(num_blocks - 1)
            ]
            # The last block absorbs the remainder so the total is exact.
            block_lens.append(total_trials - sum(block_lens))
            if min(block_lens) >= 1:
                return np.array(block_lens)

        raise ValueError(
            f'could not draw {num_blocks} non-empty blocks summing to '
            f'{total_trials} trials with block_len_var={block_len_var}; '
            'try a smaller block_len_var')